#!/usr/bin/python
# Copyright 2014 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

# pylint: disable=W0104,W0106,F0401,R0201

import errno
import optparse
import os.path
import sys

import interface


def _ScriptDir():
  return os.path.dirname(os.path.abspath(__file__))


def _GetDirAbove(dirname):
  """Returns the parent of the nearest ancestor directory named |dirname|.

  The search walks upward from this script's directory. An AssertionError is
  raised if the walk reaches the filesystem root without finding |dirname|.
  """
  current = _ScriptDir()
  while True:
    parent, leaf = os.path.split(current)
    # An empty leaf means the root was reached without a match.
    assert leaf
    if leaf == dirname:
      return parent
    current = parent


def _AddThirdPartyImportPath():
  """Prepends the 'third_party' directory next to 'mojo' to sys.path."""
  third_party_dir = os.path.join(_GetDirAbove('mojo'), 'third_party')
  sys.path.insert(0, third_party_dir)


# jinja2 is expected to be found under the 'third_party' directory next to
# 'mojo', so that directory must be on sys.path before the import below.
_AddThirdPartyImportPath()
import jinja2

# Templates are looked up by name in the directory containing this script.
loader = jinja2.FileSystemLoader(_ScriptDir())
# keep_trailing_newline keeps a template's final newline in rendered output
# (Jinja strips a single trailing newline by default).
jinja_env = jinja2.Environment(loader=loader, keep_trailing_newline=True)


# Accumulate lines of code with varying levels of indentation.
class CodeWriter(object):
  """Collects generated source lines at a stack of indentation margins.

  Lines are appended with ``writer << line``. Each line is prefixed with the
  current margin and has its trailing whitespace removed.
  """

  def __init__(self):
    self._lines = []
    self._margin = ''
    self._margin_stack = []

  def __lshift__(self, line):
    indented = self._margin + line
    self._lines.append(indented.rstrip())

  def PushMargin(self):
    """Saves the current margin and indents later lines by two more spaces."""
    previous = self._margin
    self._margin_stack.append(previous)
    self._margin = previous + '  '

  def PopMargin(self):
    """Restores the margin saved by the matching PushMargin()."""
    self._margin = self._margin_stack.pop()

  def GetValue(self):
    """Returns all lines joined, with trailing blank space trimmed to one '\\n'."""
    text = '\n'.join(self._lines)
    return text.rstrip() + '\n'

  def Indent(self):
    """Returns a context manager that indents this writer for a block."""
    return Indent(self)


# Context manager that indents a CodeWriter on entry and dedents it on exit.
class Indent(object):
  """Indents a CodeWriter for the duration of a ``with`` block.

  The margin is pushed on entry and popped on exit, including when the block
  raises. Exceptions are not suppressed, and ``as`` binds None.
  """

  def __init__(self, writer):
    self._writer = writer

  def __enter__(self):
    self._writer.PushMargin()

  def __exit__(self, exc_type, exc_value, exc_traceback):
    self._writer.PopMargin()


def TemplateFile(name):
  """Returns the absolute path of template |name| in this script's directory.

  The directory comes from _ScriptDir(), the same directory the Jinja loader
  uses. The result is therefore absolute and independent of the current
  working directory, even when __file__ is a bare relative filename (where
  os.path.dirname(__file__) would be '').
  """
  return os.path.join(_ScriptDir(), name)


# Wraps comma separated lists as needed.
# TODO(teravest): Eliminate Wrap() and use "git cl format" when code is checked
# in.
def Wrap(pre, items, post, width=80, indent='    '):
  """Formats |pre| + comma-separated |items| + |post|, wrapping if too long.

  Args:
    pre: text before the list, e.g. 'MojoResult MojoClose('.
    items: the list elements as strings (any iterable).
    post: text after the list, e.g. ') {'.
    width: maximum length of the single-line form before it is wrapped.
    indent: prefix for each item line when wrapping.

  Returns:
    A list of lines. The single-line form is returned when it fits in |width|
    characters, or when there are no items to put on continuation lines.
    Otherwise |pre| is followed by one line per item. Every item line except
    the last ends in ',', and the last ends in |post|.
  """
  items = list(items)
  complete = pre + ', '.join(items) + post
  # With no items there is nothing to wrap. Wrapping anyway would emit |pre|
  # alone and silently drop |post|.
  if len(complete) <= width or not items:
    return [complete]
  lines = [pre]
  for item in items[:-1]:
    lines.append(indent + item + ',')
  lines.append(indent + items[-1] + post)
  return lines


def GeneratorWarning():
  """Returns the two-line '//' comment banner marking a file as generated."""
  script_name = os.path.basename(__file__)
  banner = ('// WARNING this file was generated by %s\n'
            '// Do not edit by hand.')
  return banner % script_name


# Untrusted library which thunks from the public Mojo API to the IRT interface
# implementing the public Mojo API.
def GenerateLibMojo(functions, out):
  """Writes libmojo.cc: untrusted stubs that forward Mojo calls to the IRT.

  Each generated function fetches the interface with get_irt_mojo() and
  forwards its arguments to the same-named entry of struct nacl_irt_mojo. If
  the interface is NULL, the stub returns a failure value suited to its
  return type.

  Args:
    functions: function descriptions; main() passes
      interface.MakeInterface().functions.
    out: writable file-like object that receives the rendered template.
  """
  template = jinja_env.get_template('libmojo.cc.tmpl')

  code = CodeWriter()

  for f in functions:
    for line in Wrap('%s %s(' % (f.return_type, f.name), f.ParamList(), ') {'):
      code << line

    # Returning MOJO_RESULT_INTERNAL from a function declared to return
    # MojoTimeTicks would give the caller an error code disguised as a tick
    # count. Use 0 there instead, matching the default that
    # GenerateMojoIrtImplementation uses for MojoTimeTicks.
    if f.return_type == 'MojoTimeTicks':
      unavailable_value = '0'
    else:
      unavailable_value = 'MOJO_RESULT_INTERNAL'

    with code.Indent():
      code << 'struct nacl_irt_mojo* irt_mojo = get_irt_mojo();'
      code << 'if (irt_mojo == NULL)'
      with code.Indent():
        code << 'return %s;' % unavailable_value
      code << 'return irt_mojo->%s(%s);' % (
          f.name, ', '.join([p.name for p in f.params]))

    code << '}'
    code << ''

  body = code.GetValue()
  text = template.render(
    generator_warning=GeneratorWarning(),
    body=body)
  out.write(text)


# Parameters passed into trusted code are handled differently depending on
# details of the parameter.  ParamImpl instances encapsulate these differences
# and are used to generate the code that transfers parameters across the
# untrusted/trusted boundary.
class ParamImpl(object):
  """Base strategy for moving one parameter across the trust boundary.

  Subclasses generate the C++ that declares temporaries for the parameter,
  converts it from its untrusted form, names it in the trusted call, and
  copies results back into untrusted memory.
  """

  def __init__(self, param):
    self.param = param

  def DeclareVars(self, code):
    """Emits declarations for whatever temporaries this parameter needs."""
    raise NotImplementedError(type(self))

  def ConvertParam(self):
    """Returns a C++ expression converting the untrusted parameter.

    The conversion produces a trusted representation, such as a scalar value
    or a trusted pointer into the untrusted address space.
    """
    raise NotImplementedError(type(self))

  def CallParam(self):
    """Returns the expression to pass when invoking the trusted Mojo function."""
    raise NotImplementedError(type(self))

  def CopyOut(self, code):
    """Emits code copying results back into untrusted memory after the call.

    Does nothing by default; overridden for Out and InOut parameters.
    """

  def IsArray(self):
    """Returns whether conversion must be deferred until scalars are converted.

    Array conversion needs the already-converted scalar parameter that holds
    the array's size.
    """
    return False


class ScalarInputImpl(ParamImpl):
  """Input-only scalar: converted into a local <name>_value, passed by value."""

  def DeclareVars(self, code):
    param = self.param
    code << '%s %s_value;' % (param.base_type, param.name)

  def ConvertParam(self):
    param = self.param
    slot = param.uid + 1
    return 'ConvertScalarInput(nap, params[%d], &%s_value)' % (slot, param.name)

  def CallParam(self):
    return self.param.name + '_value' if False else '%s_value' % self.param.name


class ScalarOutputImpl(ParamImpl):
  """Output-only value written back through <name>_ptr after the call.

  ImplForParam uses this for always-written output scalars and for
  output-only, non-extensible structs.
  """

  def DeclareVars(self, code):
    base_type, name = self.param.base_type, self.param.name
    code << '%s volatile* %s_ptr;' % (base_type, name)
    code << '%s %s_value;' % (base_type, name)

  def ConvertParam(self):
    param = self.param
    fmt = 'ConvertScalarOutput(nap, params[%d], %s, &%s_ptr)'
    return fmt % (param.uid + 1, CBool(param.is_optional), param.name)

  def CallParam(self):
    name = self.param.name
    if not self.param.is_optional:
      return '&%s_value' % name
    # An omitted optional output is passed on as NULL.
    return '%s_ptr ? &%s_value : NULL' % (name, name)

  def CopyOut(self, code):
    name = self.param.name
    if not self.param.is_struct:
      statement = '*%s_ptr = %s_value;' % (name, name)
    else:
      # C++ rejects copying into a volatile struct pointer (there is no
      # default copy for that case), so copy the bytes with memcpy instead.
      statement = ('memcpy_volatile_out(%s_ptr, &%s_value, sizeof(%s));' %
                   (name, name, self.param.base_type))

    if not self.param.is_optional:
      code << statement
      return
    code << 'if (%s_ptr != NULL) {' % (name)
    with code.Indent():
      code << statement
    code << '}'


class ScalarInOutImpl(ParamImpl):
  """Scalar that is read before the call and written back through <name>_ptr.

  ConvertScalarInOut receives both &<name>_value and &<name>_ptr. Presumably
  it fills <name>_value from untrusted memory (confirm in the trusted
  helper). After the call, CopyOut stores <name>_value through <name>_ptr.
  """

  def DeclareVars(self, code):
    base_type, name = self.param.base_type, self.param.name
    code << '%s volatile* %s_ptr;' % (base_type, name)
    code << '%s %s_value;' % (base_type, name)

  def ConvertParam(self):
    param = self.param
    fmt = 'ConvertScalarInOut(nap, params[%d], %s, &%s_value, &%s_ptr)'
    return fmt % (param.uid + 1, CBool(param.is_optional), param.name,
                  param.name)

  def CallParam(self):
    name = self.param.name
    if not self.param.is_optional:
      return '&%s_value' % name
    # An omitted optional argument is passed on as NULL.
    return '%s_ptr ? &%s_value : NULL' % (name, name)

  def CopyOut(self, code):
    name = self.param.name
    store = '*%s_ptr = %s_value;' % (name, name)
    if not self.param.is_optional:
      code << store
      return
    code << 'if (%s_ptr != NULL) {' % (name)
    with code.Indent():
      code << store
    code << '}'


class ArrayImpl(ParamImpl):
  """Array parameter whose length is another parameter's <size>_value.

  Because conversion reads the converted size parameter, IsArray() returns
  True so the generator emits this conversion after all non-array ones.
  """

  def DeclareVars(self, code):
    param = self.param
    code << '%s %s;' % (param.param_type, param.name)

  def ConvertParam(self):
    p = self.param
    # void arrays are measured in bytes; typed arrays use their element size.
    element_size = '1' if p.base_type == 'void' else 'sizeof(*%s)' % p.name
    size_expr = p.size + '_value'
    fmt = 'ConvertArray(nap, params[%d], %s, %s, %s, &%s)'
    return fmt % (p.uid + 1, size_expr, element_size, CBool(p.is_optional),
                  p.name)

  def CallParam(self):
    return self.param.name

  def IsArray(self):
    return True


class ExtensibleStructInputImpl(ParamImpl):
  """Input-only extensible struct: declared with its full parameter type,
  converted by ConvertExtensibleStructInput, and passed to the call as-is."""

  def DeclareVars(self, code):
    param = self.param
    code << '%s %s;' % (param.param_type, param.name)

  def ConvertParam(self):
    param = self.param
    fmt = 'ConvertExtensibleStructInput(nap, params[%d], %s, &%s)'
    return fmt % (param.uid + 1, CBool(param.is_optional), param.name)

  def CallParam(self):
    return self.param.name

def ImplForParam(p):
  """Returns the ParamImpl strategy that marshals parameter |p|.

  Raises:
    AssertionError: with p.name as the message, if |p| is a kind of parameter
      the generator cannot marshal. It is raised explicitly rather than with
      an assert statement, so the check still fires under `python -O`
      instead of silently returning None.
  """
  if p.IsScalar():
    if not p.is_output:
      return ScalarInputImpl(p)
    if p.is_input or not p.is_always_written:
      # Mojo leaves some outputs unset in specific cases. Rather than work out
      # whether Mojo set the output, treat such outputs as in/out: copy the
      # current (possibly junk) value in and copy it back afterwards.
      return ScalarInOutImpl(p)
    return ScalarOutputImpl(p)
  if p.is_array:
    return ArrayImpl(p)
  if p.is_struct:
    if p.is_input and not p.is_output and p.is_extensible:
      return ExtensibleStructInputImpl(p)
    if not p.is_input and p.is_output and not p.is_extensible:
      return ScalarOutputImpl(p)
  raise AssertionError(p.name)


def CBool(value):
  """Returns the C/C++ boolean literal for the truthiness of |value|."""
  if value:
    return 'true'
  return 'false'


# A trusted wrapper that validates the arguments passed from untrusted code
# before passing them to the underlying public Mojo API.
def GenerateMojoSyscall(functions, out):
  """Writes mojo_syscall.cc, the trusted side of the Mojo call bridge.

  For each function this emits a `case <uid>:` block with these steps:
  reject a wrong parameter count, declare temporaries, convert and validate
  every untrusted parameter, call the real Mojo function, and copy outputs
  back into untrusted memory. Every generated failure path returns -1, and
  success returns 0. The cases are rendered into mojo_syscall.cc.tmpl as
  |body|.

  Args:
    functions: function descriptions; main() passes
      interface.MakeInterface().functions.
    out: writable file-like object that receives the rendered template.
  """
  template = jinja_env.get_template('mojo_syscall.cc.tmpl')

  code = CodeWriter()
  # Every case is indented one level. Presumably the template places |body|
  # inside a switch statement; confirm against mojo_syscall.cc.tmpl.
  code.PushMargin()

  for f in functions:
    # One marshaling strategy per parameter; the result parameter goes last.
    impls = [ImplForParam(p) for p in f.params]
    impls.append(ImplForParam(f.result_param))

    code << 'case %d:' % f.uid

    # Indent this case's body; undone by the PopMargin() at the loop's end.
    code.PushMargin()

    code << '{'

    with code.Indent():
      # The untrusted params array holds the message ID and the result
      # pointer in addition to one entry per regular parameter.
      num_params = len(f.params) + 2
      code << 'if (num_params != %d) {' % num_params
      with code.Indent():
        code << 'return -1;'
      code << '}'

      # Declare temporaries.
      for impl in impls:
        impl.DeclareVars(code)

      # Emits a check that returns -1 when |impl|'s conversion expression
      # evaluates to false.
      def ConvertParam(code, impl):
        code << 'if (!%s) {' % impl.ConvertParam()
        with code.Indent():
          code << 'return -1;'
        code << '}'

      # All conversions are emitted inside one block that holds a
      # ScopedCopyLock for its duration.
      code << '{'
      with code.Indent():
        code << 'ScopedCopyLock copy_lock(nap);'
        # Convert and validate pointers in two passes.
        # Arrays cannot be validated until the size parameter has been
        # converted.
        for impl in impls:
          if not impl.IsArray():
            ConvertParam(code, impl)
        for impl in impls:
          if impl.IsArray():
            ConvertParam(code, impl)
      code << '}'
      code << ''

      # Call
      # The result parameter (the last impl) is not passed as an argument;
      # the return value is assigned to result_value instead.
      # NOTE(review): this assumes f.result_param.name == 'result', so that
      # its declared '<name>_value' temporary is result_value. Confirm in
      # interface.py.
      getParams = [impl.CallParam() for impl in impls[:-1]]
      code << 'result_value = %s(%s);' % (f.name, ', '.join(getParams))
      code << ''

      # Write outputs
      # Copy-back also runs under a ScopedCopyLock. CopyOut is a no-op for
      # impls that do not override it (scalar inputs, arrays, extensible
      # struct inputs).
      code << '{'
      with code.Indent():
        code << 'ScopedCopyLock copy_lock(nap);'
        for impl in impls:
          impl.CopyOut(code)
      code << '}'
      code << ''

      code << 'return 0;'
    code << '}'

    code.PopMargin()

  body = code.GetValue()
  text = template.render(
    generator_warning=GeneratorWarning(),
    body=body)
  out.write(text)


# A header declaring the IRT interface for accessing Mojo functions.
def GenerateMojoIrtHeader(functions, out):
  """Writes mojo_irt.h, declaring struct nacl_irt_mojo.

  The struct has one function-pointer member per Mojo function, in the order
  of |functions|. Each member has the same return type and parameter list as
  the function it is named after.

  Args:
    functions: function descriptions; main() passes
      interface.MakeInterface().functions.
    out: writable file-like object that receives the rendered template.
  """
  template = jinja_env.get_template('mojo_irt.h.tmpl')
  writer = CodeWriter()

  writer << 'struct nacl_irt_mojo {'
  with writer.Indent():
    for func in functions:
      member_prefix = '%s (*%s)(' % (func.return_type, func.name)
      member_lines = Wrap(member_prefix, func.ParamList(), ');')
      for member_line in member_lines:
        writer << member_line
  writer << '};'

  rendered_body = writer.GetValue()
  out.write(template.render(generator_warning=GeneratorWarning(),
                            body=rendered_body))

# IRT interface which implements the Mojo public API.
def GenerateMojoIrtImplementation(functions, out):
  """Writes mojo_irt.c: IRT functions that forward Mojo calls to trusted code.

  Each generated static irt_<Name> function fills a uint32_t params array
  with the message ID, its arguments, and a pointer to its result variable.
  It then calls DoMojoCall(params, sizeof(params)) and returns the result
  variable. The file ends with kIrtMojo, a struct nacl_irt_mojo initialized
  with a pointer to every generated function.

  Args:
    functions: function descriptions; main() passes
      interface.MakeInterface().functions.
    out: writable file-like object that receives the rendered template.

  Raises:
    Exception: if a function's result type has no known default value.
  """
  template = jinja_env.get_template('mojo_irt.c.tmpl')
  code = CodeWriter()

  for f in functions:
    for line in Wrap('static %s irt_%s(' % (f.return_type, f.name),
                     f.ParamList(),
                     ') {'):
      code << line

    # 2 extra parameters: message ID and return value.
    num_params = len(f.params) + 2

    with code.Indent():
      code << 'uint32_t params[%d];' % num_params
      return_type = f.result_param.base_type
      # The result is pre-initialized so that a defined value is returned
      # even if DoMojoCall does not store one (presumably on failure;
      # confirm in DoMojoCall).
      if return_type == 'MojoResult':
        default = 'MOJO_RESULT_INVALID_ARGUMENT'
      elif return_type == 'MojoTimeTicks':
        default = '0'
      else:
        raise Exception('Unhandled return type: ' + return_type)
      code << '%s %s = %s;' % (return_type, f.result_param.name, default)

      # Message ID
      code << 'params[0] = %d;' % f.uid
      # Parameter pointers. By-value arguments are sent as the address of the
      # local parameter; pointer arguments are sent as-is. Casting pointers to
      # uint32_t assumes a 32-bit untrusted address space (presumably NaCl).
      cast_template = 'params[%d] = (uint32_t)(%s);'
      for p in f.params:
        ptr = p.name
        if p.IsPassedByValue():
          ptr = '&' + ptr
        code << cast_template % (p.uid + 1, ptr)
      # Return value pointer
      code << cast_template % (num_params - 1, '&' + f.result_param.name)

      code << 'DoMojoCall(params, sizeof(params));'
      code << 'return %s;' % f.result_param.name

    # Close the function definition with '}' alone. A trailing ';' here would
    # be an empty declaration at file scope, which ISO C does not permit
    # (GCC -pedantic: "extra ';' outside of a function").
    code << '}'
    code << ''

  # Now we've emitted all the functions, but we still need the struct
  # definition.
  code << 'struct nacl_irt_mojo kIrtMojo = {'
  with code.Indent():
    for f in functions:
      code << '&irt_%s,' % f.name
  code << '};'

  body = code.GetValue()

  text = template.render(
    generator_warning=GeneratorWarning(),
    body=body)
  out.write(text)


def OutFile(dir_path, name):
  """Opens |dir_path|/|name| for writing, creating |dir_path| if it is missing.

  Another process may create |dir_path| between the existence check and
  os.makedirs(), so an EEXIST failure is tolerated. Any other OSError
  propagates.

  Returns:
    A file object opened in text mode 'w'; the caller must close it.
  """
  needs_directory = not os.path.exists(dir_path)
  if needs_directory:
    try:
      os.makedirs(dir_path)
    except OSError as err:
      if err.errno != errno.EEXIST:
        raise
  file_path = os.path.join(dir_path, name)
  return open(file_path, 'w')


def main(args):
  """Generates libmojo.cc, mojo_syscall.cc, mojo_irt.h and mojo_irt.c.

  Args:
    args: command-line arguments, excluding the program name. The -d DIR
      option is required, and positional arguments are rejected.
  """
  usage = 'usage: %prog [options]'
  parser = optparse.OptionParser(usage=usage)
  parser.add_option(
      '-d',
      dest='out_dir',
      metavar='DIR',
      help='output generated code into directory DIR')
  options, args = parser.parse_args(args=args)
  if not options.out_dir:
    parser.error('-d is required')
  if args:
    parser.error('unexpected positional arguments: %s' % ' '.join(args))

  mojo = interface.MakeInterface()

  generators = [
      ('libmojo.cc', GenerateLibMojo),
      ('mojo_syscall.cc', GenerateMojoSyscall),
      ('mojo_irt.h', GenerateMojoIrtHeader),
      ('mojo_irt.c', GenerateMojoIrtImplementation),
  ]
  for file_name, generate in generators:
    # Close (and so flush) each output deterministically, even if a generator
    # raises, instead of relying on garbage collection of the file object.
    with OutFile(options.out_dir, file_name) as out:
      generate(mojo.functions, out)

# When run as a script, pass the command-line arguments (minus the program
# name) to main().
if __name__ == '__main__':
  main(sys.argv[1:])
